import math
import matplotlib.pyplot as plt
import mindspore.dataset as ds

# Seed MindSpore's dataset-pipeline RNG so any random operations are reproducible.
ds.config.set_seed(1)

# Location of the CIFAR-10 binary batch files (relative to the working directory).
DATA_DIR_CIFAR10 = './datasets/cifar10/cifar-10-batches-bin/'

# Read only the first 5 samples, in on-disk order (no shuffling).
dataset_cifar10 = ds.Cifar10Dataset(
    dataset_dir=DATA_DIR_CIFAR10,
    sampler=ds.SequentialSampler(num_samples=5),
)


def plt_result(dataset, row):
    """Print and plot every sample of an image dataset in a subplot grid.

    For each sample, prints its image shape and label, then draws the image
    in its own subplot titled with the label. The grid has ``row`` rows and
    just enough columns to hold ``dataset.get_dataset_size()`` samples.

    Args:
        dataset: a MindSpore dataset whose dict iterator yields ``'image'``
            and ``'label'`` entries supporting ``.shape`` / ``.asnumpy()``
            (e.g. ``Cifar10Dataset``).
        row: number of subplot rows; must be a positive integer.

    Raises:
        ValueError: if ``row`` is not positive.
    """
    if row <= 0:
        raise ValueError(f"row must be a positive integer, got {row!r}")

    # The grid shape is the same for every sample, so compute it once rather
    # than calling get_dataset_size() on each iteration.
    cols = math.ceil(dataset.get_dataset_size() / row)

    # Subplot indices are 1-based in matplotlib.
    for num, data in enumerate(dataset.create_dict_iterator(), start=1):
        print('Image shape:', data['image'].shape, ',label:', data['label'])
        plt.subplot(row, cols, num)
        plt.title(str(data['label']))
        # "none" requests no interpolation, so small images keep crisp pixels.
        # This is the canonical spelling of the former "None"; matplotlib
        # lowercases the value, so the resulting behavior is unchanged.
        plt.imshow(data['image'].asnumpy(), interpolation="none")
    plt.show()

if __name__ == "__main__":
    # Plot only when run as a script, so importing this module does not open
    # a blocking figure window.
    plt_result(dataset_cifar10, 1)
